{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vwELCooy4ljr"
      },
      "source": [
        "# Prompt-driven search with LLMs\n",
        "\n",
        "This notebook revisits the Extractor pipeline, which has been covered in a number of previous notebooks. This pipeline is a combination of a similarity instance (embeddings or similarity pipeline) to build a question context and a model that answers questions.\n",
        "\n",
        "The Extractor pipeline recently underwent a number of major upgrades to support the following.\n",
        "\n",
        "- Ability to run embeddings searches. Given that content is supported, text can be retrieved from the embeddings instance.\n",
        "- In addition to extractive qa, support text generation models, sequence to sequence models and custom pipelines\n",
        "\n",
        "These changes enable embeddings-guided and prompt-driven search with Large Language Models (LLMs) 🔥🔥🔥"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ew7orE2O441o"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LPQTb25tASIG"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai datasets"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_YnqorRKAbLu"
      },
      "source": [
        "# Create Embeddings and Extractor instances\n",
        "\n",
        "An Embeddings instance defines methods to represent text as vectors and build vector indexes for search.\n",
        "\n",
        "The Extractor pipeline is a combination of a similarity instance (embeddings or similarity pipeline) to build a question context and a model that answers questions. The model can be a prompt-driven large language model (LLM), an extractive question-answering model or a custom pipeline.\n",
        "\n",
        "Let's run a basic example.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OUc9gqTyAYnm"
      },
      "source": [
        "%%capture\n",
        "\n",
        "from txtai.embeddings import Embeddings\n",
        "from txtai.pipeline import Extractor\n",
        "\n",
        "# Create embeddings model with content support\n",
        "embeddings = Embeddings({\"path\": \"sentence-transformers/all-MiniLM-L6-v2\", \"content\": True})\n",
        "\n",
        "# Create extractor instance\n",
        "extractor = Extractor(embeddings, \"google/flan-t5-base\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4X5z3UjnAGe7",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cacb7e9d-471a-437d-c68d-a8b51f876413"
      },
      "source": [
        "data = [\"Giants hit 3 HRs to down Dodgers\",\n",
        "        \"Giants 5 Dodgers 4 final\",\n",
        "        \"Dodgers drop Game 2 against the Giants, 5-4\",\n",
        "        \"Blue Jays beat Red Sox final score 2-1\",\n",
        "        \"Red Sox lost to the Blue Jays, 2-1\",\n",
        "        \"Blue Jays at Red Sox is over. Score: 2-1\",\n",
        "        \"Phillies win over the Braves, 5-0\",\n",
        "        \"Phillies 5 Braves 0 final\",\n",
        "        \"Final: Braves lose to the Phillies in the series opener, 5-0\",\n",
        "        \"Lightning goaltender pulled, lose to Flyers 4-1\",\n",
        "        \"Flyers 4 Lightning 1 final\",\n",
        "        \"Flyers win 4-1\"]\n",
        "\n",
        "def prompt(question):\n",
        "  return f\"\"\"\n",
        "    Answer the following question using the context below.\n",
        "    Question: {question}\n",
        "    Context:\n",
        "  \"\"\"\n",
        "\n",
        "questions = [\"What team won the game?\", \"What was score?\"]\n",
        "\n",
        "execute = lambda query: extractor([(question, query, prompt(question), False) for question in questions], data)\n",
        "\n",
        "for query in [\"Red Sox - Blue Jays\", \"Phillies - Braves\", \"Dodgers - Giants\", \"Flyers - Lightning\"]:\n",
        "    print(\"----\", query, \"----\")\n",
        "    for answer in execute(query):\n",
        "        print(answer)\n",
        "    print()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "---- Red Sox - Blue Jays ----\n",
            "('What team won the game?', 'Blue Jays')\n",
            "('What was score?', '2-1')\n",
            "\n",
            "---- Phillies - Braves ----\n",
            "('What team won the game?', 'Phillies')\n",
            "('What was score?', '5-0')\n",
            "\n",
            "---- Dodgers - Giants ----\n",
            "('What team won the game?', 'Giants')\n",
            "('What was score?', '5-4')\n",
            "\n",
            "---- Flyers - Lightning ----\n",
            "('What team won the game?', 'Flyers')\n",
            "('What was score?', '4-1')\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This code runs a series of questions. First it runs an embeddings filtering query to find the most relevant text. For example, `Red Sox - Blue Jays` finds text related to those teams. Then `What team won the game?` and `What was the score?` are asked.\n",
        "\n",
        "This logic is the same logic found in Notebook 5 - Extractive QA with txtai but uses prompt-based QA vs extractive QA. "
      ],
      "metadata": {
        "id": "7AnPvSeM3N1Z"
      }
    },
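    {
      "cell_type": "markdown",
      "source": [
        "The four-element tuples passed to the Extractor above can be spelled out field by field. The sketch below is a plain-Python illustration that needs no model. The field names `name`, `query`, `question` and `snippet` reflect our reading of the txtai documentation and should be treated as an assumption; the `Task` type, `build_queue` helper and `demo_prompt` function are hypothetical names introduced only for this example."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from collections import namedtuple\n",
        "\n",
        "# Hypothetical helper type: names the four positions of each Extractor input tuple\n",
        "Task = namedtuple(\"Task\", [\"name\", \"query\", \"question\", \"snippet\"])\n",
        "\n",
        "def build_queue(query, questions, prompt):\n",
        "    \"\"\"Builds one task per question, all sharing the same embeddings query.\"\"\"\n",
        "    return [Task(question, query, prompt(question), False) for question in questions]\n",
        "\n",
        "def demo_prompt(question):\n",
        "    # Simplified stand-in for the prompt function defined earlier\n",
        "    return f\"Answer the following question using the context below.\\nQuestion: {question}\\nContext:\\n\"\n",
        "\n",
        "queue = build_queue(\"Red Sox - Blue Jays\", [\"What team won the game?\", \"What was score?\"], demo_prompt)\n",
        "for task in queue:\n",
        "    print(task.name, \"|\", task.query, \"|\", task.snippet)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },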
    {
      "cell_type": "markdown",
      "source": [
        "# Embeddings-guided and Prompt-driven Search\n",
        "\n",
        "Now for the fun stuff. Let's build an embeddings index for the `ag_news` dataset (a set of news stories from the mid 2000s). Then we'll use prompts to ask questions with embeddings results as the context."
      ],
      "metadata": {
        "id": "Aj8GoDk331cS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"ag_news\", split=\"train\")\n",
        "\n",
        "# List of all text elements\n",
        "texts = dataset[\"text\"]\n",
        "\n",
        "# Create an embeddings index over the dataset\n",
        "embeddings = Embeddings({\"path\": \"sentence-transformers/all-MiniLM-L6-v2\", \"content\": True})\n",
        "embeddings.index((x, text, None) for x, text in enumerate(texts))\n",
        "\n",
        "# Create extractor instance\n",
        "extractor = Extractor(embeddings, \"google/flan-t5-large\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yL716oEZ43t-",
        "outputId": "23f4b0e7-a60a-4e89-fb57-06966e6612f8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:datasets.builder:Found cached dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now let's run a prompt-driven search!"
      ],
      "metadata": {
        "id": "Ifl8JwLDBL7k"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def prompt(question):\n",
        "  return f\"\"\"Answer the following question using only the context below. Say 'no answer' when the question can't be answered.\n",
        "Question: {question}\n",
        "Context: \"\"\"\n",
        "\n",
        "def search(query, question=None):\n",
        "  # Default question to query if empty\n",
        "  if not question:\n",
        "    question = query\n",
        "\n",
        "  return extractor([(\"answer\", query, prompt(question), False)])[0][1]\n",
        "\n",
        "question = \"Who won the 2004 presidential election?\"\n",
        "answer = search(question)\n",
        "print(question, answer)\n",
        "\n",
        "nquestion = \"Who did the candidate beat?\"\n",
        "print(nquestion, search(f\"{question} {answer}. {nquestion}\"))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5O1WBJ8153Mo",
        "outputId": "ddf09da2-7d4c-4fd3-b0da-df5631bacd13"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Who won the 2004 presidential election? George W. Bush\n",
            "Who did the candidate beat? John F. Kerry\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "And there are the answers. Let's unpack how this works.\n",
        "\n",
        "The first thing the Extractor pipeline does is run an embeddings search to find the most relevant text within the index. A context string is then built using those search results.\n",
        "\n",
        "After that, a prompt is generated, run and the answer printed. Let's see what a full prompt looks like."
      ],
      "metadata": {
        "id": "AhViFXH_BZSo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "text = prompt(question)\n",
        "text += \"\\n\" + \"\\n\".join(x[\"text\"]for x in embeddings.search(question))\n",
        "print(text)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h7BsoJC3CFOq",
        "outputId": "44ae0dec-4f79-4d44-88f5-585e5eae9cb2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Answer the following question using only the context below. Say 'no answer' when the question can't be answered.\n",
            "Question: Who won the 2004 presidential election?\n",
            "Context: \n",
            "Right- and left-click politics The 2004 presidential race ended last week in a stunning defeat for Massachusetts Senator John F. Kerry, as incumbent President George W. Bush cruised to an easy victory.\n",
            "2004 Presidential Endorsements (AP) AP - Newspaper endorsements in the 2004 presidential campaign between President Bush, a Republican, and Sen. John Kerry, a Democrat.\n",
            "Presidential Campaign to Nov. 2, 2004 (Reuters) Reuters - The following diary of events\\leading up to the presidential election on Nov. 2.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "The prompt has the information needed to determine the answers to the questions."
      ],
      "metadata": {
        "id": "4m9XketCty3m"
      }
    },
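    {
      "cell_type": "markdown",
      "source": [
        "The assembly step above can be wrapped in a small helper. The sketch below is a plain-Python illustration and not the Extractor's internal code: `build_prompt` is a hypothetical name, and `results` is a hand-written stand-in (with made-up scores) for the list of dictionaries with a `text` key that `embeddings.search` returns in this notebook."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def build_prompt(prefix, results):\n",
        "    \"\"\"Appends the text of each search result to a prompt prefix, one result per line.\"\"\"\n",
        "    context = \"\\n\".join(row[\"text\"] for row in results)\n",
        "    return f\"{prefix}\\n{context}\"\n",
        "\n",
        "# Hand-written stand-in for embeddings search results (not real index output)\n",
        "results = [{\"text\": \"Giants 5 Dodgers 4 final\", \"score\": 0.9},\n",
        "           {\"text\": \"Giants hit 3 HRs to down Dodgers\", \"score\": 0.8}]\n",
        "\n",
        "print(build_prompt(\"Question: What team won the game?\\nContext: \", results))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },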
    {
      "cell_type": "markdown",
      "source": [
        "# Additional examples\n",
        "\n",
        "Before moving on, a couple more example questions."
      ],
      "metadata": {
        "id": "JtDcVPdOB0Rv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "question = \"Who won the World Series in 2004?\"\n",
        "answer = search(question)\n",
        "print(question, answer)\n",
        "\n",
        "nquestion = \"Who did they beat?\"\n",
        "print(nquestion, search(f\"{question} {answer}. {nquestion}\"))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0NNLBwC-83MM",
        "outputId": "4b9fabd7-baf9-4f7d-bb8b-c9c60c5101e4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Who won the World Series in 2004? Boston\n",
            "Who did they beat? St Louis\n"
          ]
        }
      ]
    },
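    {
      "cell_type": "markdown",
      "source": [
        "Both follow-up questions so far are answered by folding the previous question and answer into the next search query. A minimal sketch of that pattern follows, using a hypothetical `chain` helper and a stub in place of the real `search` function, so it runs without a model."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def chain(searchfn, question, followup):\n",
        "    \"\"\"Answers a question, then asks a follow-up that includes the first exchange.\"\"\"\n",
        "    answer = searchfn(question)\n",
        "    return answer, searchfn(f\"{question} {answer}. {followup}\")\n",
        "\n",
        "# Stub that records the queries it receives (stands in for the real search function)\n",
        "queries = []\n",
        "def stub(query):\n",
        "    queries.append(query)\n",
        "    return \"Boston\"\n",
        "\n",
        "first, second = chain(stub, \"Who won the World Series in 2004?\", \"Who did they beat?\")\n",
        "print(queries[-1])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },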
    {
      "cell_type": "code",
      "source": [
        "search(\"Tell me something interesting?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "id": "1P0zqkTW9cZW",
        "outputId": "4f2232db-761b-464f-b47d-6695c84ffb80"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'herrings communicate by farting'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Whhaaaattt???  Is this a model hallucination?\n",
        "\n",
        "Let's run an embeddings query and see if that text is in the results."
      ],
      "metadata": {
        "id": "ygFFcwWPGI9p"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "answer = \"herrings communicate by farting\"\n",
        "for x in embeddings.search(\"Tell me something interesting?\"):\n",
        "  if answer in x[\"text\"]:\n",
        "    start = x[\"text\"].find(answer)\n",
        "    print(x[\"text\"][start:start + len(answer)])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qZPhLqSxGMbK",
        "outputId": "e3b2909a-1ee8-480f-afa3-201a8e27cb08"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "herrings communicate by farting\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Sure enough it is. It appears that the FLAN-T5 model has a bit of an immature sense of humor 😃"
      ],
      "metadata": {
        "id": "IpgxMc1DZcds"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# External API Integration\n",
        "\n",
        "In addition to support for Hugging Face models, the Extractor pipeline also supports custom question-answer models. This could be a call to the OpenAI API (GPT-3), Cohere API, Hugging Face API or using langchain to manage that. All that is needed is a Callable object or a function!\n",
        "\n",
        "Let's see an example that uses the Hugging Face API to answer questions. We'll use the original sports dataset to demonstrate."
      ],
      "metadata": {
        "id": "7bitjzeyYOqh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import requests\n",
        "\n",
        "data = [\"Giants hit 3 HRs to down Dodgers\",\n",
        "        \"Giants 5 Dodgers 4 final\",\n",
        "        \"Dodgers drop Game 2 against the Giants, 5-4\",\n",
        "        \"Blue Jays beat Red Sox final score 2-1\",\n",
        "        \"Red Sox lost to the Blue Jays, 2-1\",\n",
        "        \"Blue Jays at Red Sox is over. Score: 2-1\",\n",
        "        \"Phillies win over the Braves, 5-0\",\n",
        "        \"Phillies 5 Braves 0 final\",\n",
        "        \"Final: Braves lose to the Phillies in the series opener, 5-0\",\n",
        "        \"Lightning goaltender pulled, lose to Flyers 4-1\",\n",
        "        \"Flyers 4 Lightning 1 final\",\n",
        "        \"Flyers win 4-1\"]\n",
        "\n",
        "def prompt(question):\n",
        "  return f\"\"\"\n",
        "    Answer the following question using the context below.\n",
        "    Question: {question}\n",
        "    Context:\n",
        "  \"\"\"\n",
        "\n",
        "# Submits a series of prompts to the Hugging Face API.\n",
        "# This call can easily be switched to use the OpenAI API (GPT-3), Cohere API or a library like langchain.\n",
        "def api(prompts):\n",
        "  response = requests.post(\"https://api-inference.huggingface.co/models/google/flan-t5-base\",\n",
        "                           json={\"inputs\": prompts})\n",
        "\n",
        "  return [x[\"generated_text\"] for x in response.json()]\n",
        "\n",
        "# Create embeddings model with content support\n",
        "embeddings = Embeddings({\"path\": \"sentence-transformers/all-MiniLM-L6-v2\", \"content\": True})\n",
        "\n",
        "# Create extractor instance, submit prompts to the Hugging Face inference API\n",
        "extractor = Extractor(embeddings, api)\n",
        "\n",
        "questions = [\"What team won the game?\", \"What was score?\"]\n",
        "\n",
        "execute = lambda query: extractor([(question, query, prompt(question), False) for question in questions], data)\n",
        "\n",
        "for query in [\"Red Sox - Blue Jays\", \"Phillies - Braves\", \"Dodgers - Giants\", \"Flyers - Lightning\"]:\n",
        "    print(\"----\", query, \"----\")\n",
        "    for answer in execute(query):\n",
        "        print(answer)\n",
        "    print()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yYAa9N7EYm7E",
        "outputId": "72208327-bfb7-466f-c664-3630f207b6dd"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "---- Red Sox - Blue Jays ----\n",
            "('What team won the game?', 'Blue Jays')\n",
            "('What was score?', '2-1')\n",
            "\n",
            "---- Phillies - Braves ----\n",
            "('What team won the game?', 'Phillies')\n",
            "('What was score?', '5-0')\n",
            "\n",
            "---- Dodgers - Giants ----\n",
            "('What team won the game?', 'Giants')\n",
            "('What was score?', '5-4')\n",
            "\n",
            "---- Flyers - Lightning ----\n",
            "('What team won the game?', 'Flyers')\n",
            "('What was score?', '4-1')\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Everything matches with first example above in [Create Embeddings and Extractor instances](#Create-Embeddings-and-Extractor-instances) except the prompts are run as an external API call.\n",
        "\n",
        "The Embeddings instance can also swap out the vectorization, database and vector store components with external API services. Check out the [txtai documentation](https://neuml.github.io/txtai/embeddings/configuration/) documentation for more information.\n",
        "\n"
      ],
      "metadata": {
        "id": "Xh8UhC3NdsGf"
      }
    },
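    {
      "cell_type": "markdown",
      "source": [
        "The callable contract used by `api` is small: it receives a list of prompt strings and returns one answer string per prompt, in the same order. The sketch below illustrates that shape with a hypothetical `LastLineModel` class that makes no network calls and simply returns the last non-empty line of each prompt. It is a stand-in for checking the wiring offline, not a real question-answering model. Under that assumption, an instance should be usable wherever `api` is passed, for example `Extractor(embeddings, LastLineModel())`."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "class LastLineModel:\n",
        "    \"\"\"Hypothetical offline stand-in for an external LLM API.\"\"\"\n",
        "\n",
        "    def __call__(self, prompts):\n",
        "        # Return one answer per prompt: the last non-empty line of that prompt\n",
        "        answers = []\n",
        "        for text in prompts:\n",
        "            lines = [line.strip() for line in text.splitlines() if line.strip()]\n",
        "            answers.append(lines[-1] if lines else \"\")\n",
        "        return answers\n",
        "\n",
        "model = LastLineModel()\n",
        "print(model([\"Question: Who won?\\nContext:\\nFlyers win 4-1\"]))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },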
    {
      "cell_type": "markdown",
      "source": [
        "# Wrapping up\n",
        "\n",
        "This notebook covered how to run embeddings-guided and prompt-driven search with LLMs. This functionality is a major step forward towards `Generative Semantic Search` for txtai. More to come, stay tuned!"
      ],
      "metadata": {
        "id": "KqfvCXp2B3li"
      }
    }
  ]
}